{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Format DataFrame\n",
    "\n",
    "Be advised, this dataset (SKLearn's Forest Cover Types) can take a little while to download...\n",
    "\n",
    "This is a multi-class classification task, in which the target is label-encoded.\n",
    "\n",
    "We'll also subtract one from the targets, to make the seven labels fall within the range of 0-6, rather than the default range of 1-7. This is to keep CatBoost from complaining."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(581012, 55)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x_0</th>\n",
       "      <th>x_1</th>\n",
       "      <th>x_2</th>\n",
       "      <th>x_3</th>\n",
       "      <th>x_4</th>\n",
       "      <th>x_5</th>\n",
       "      <th>x_6</th>\n",
       "      <th>x_7</th>\n",
       "      <th>x_8</th>\n",
       "      <th>x_9</th>\n",
       "      <th>...</th>\n",
       "      <th>x_45</th>\n",
       "      <th>x_46</th>\n",
       "      <th>x_47</th>\n",
       "      <th>x_48</th>\n",
       "      <th>x_49</th>\n",
       "      <th>x_50</th>\n",
       "      <th>x_51</th>\n",
       "      <th>x_52</th>\n",
       "      <th>x_53</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3247.0</td>\n",
       "      <td>289.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>268.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>1624.0</td>\n",
       "      <td>186.0</td>\n",
       "      <td>238.0</td>\n",
       "      <td>193.0</td>\n",
       "      <td>2525.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3200.0</td>\n",
       "      <td>46.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>162.0</td>\n",
       "      <td>45.0</td>\n",
       "      <td>1592.0</td>\n",
       "      <td>223.0</td>\n",
       "      <td>200.0</td>\n",
       "      <td>105.0</td>\n",
       "      <td>2254.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2368.0</td>\n",
       "      <td>48.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>277.0</td>\n",
       "      <td>121.0</td>\n",
       "      <td>1260.0</td>\n",
       "      <td>224.0</td>\n",
       "      <td>196.0</td>\n",
       "      <td>99.0</td>\n",
       "      <td>1237.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2828.0</td>\n",
       "      <td>50.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>417.0</td>\n",
       "      <td>73.0</td>\n",
       "      <td>1252.0</td>\n",
       "      <td>225.0</td>\n",
       "      <td>215.0</td>\n",
       "      <td>123.0</td>\n",
       "      <td>962.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2932.0</td>\n",
       "      <td>32.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>618.0</td>\n",
       "      <td>55.0</td>\n",
       "      <td>638.0</td>\n",
       "      <td>218.0</td>\n",
       "      <td>217.0</td>\n",
       "      <td>134.0</td>\n",
       "      <td>1092.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 55 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      x_0    x_1   x_2    x_3    x_4     x_5    x_6    x_7    x_8     x_9 ...  \\\n",
       "0  3247.0  289.0  12.0  268.0   40.0  1624.0  186.0  238.0  193.0  2525.0 ...   \n",
       "1  3200.0   46.0  17.0  162.0   45.0  1592.0  223.0  200.0  105.0  2254.0 ...   \n",
       "2  2368.0   48.0  19.0  277.0  121.0  1260.0  224.0  196.0   99.0  1237.0 ...   \n",
       "3  2828.0   50.0  11.0  417.0   73.0  1252.0  225.0  215.0  123.0   962.0 ...   \n",
       "4  2932.0   32.0  11.0  618.0   55.0   638.0  218.0  217.0  134.0  1092.0 ...   \n",
       "\n",
       "   x_45  x_46  x_47  x_48  x_49  x_50  x_51  x_52  x_53  y  \n",
       "0   0.0   1.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  0  \n",
       "1   0.0   1.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  0  \n",
       "2   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  2  \n",
       "3   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  1  \n",
       "4   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0   0.0  0  \n",
       "\n",
       "[5 rows x 55 columns]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import fetch_covtype\n",
    "\n",
    "data = fetch_covtype(shuffle=True, random_state=32)\n",
    "train_df = pd.DataFrame(data.data, columns=[\"x_{}\".format(_) for _ in range(data.data.shape[1])])\n",
    "train_df[\"y\"] = data.target - 1\n",
    "\n",
    "print(train_df.shape)\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cross-Experiment Key:   'S8Q6MRvEmEfVCMfR0XLC03UB2-5lkgXQPowUlVgqREs='\n"
     ]
    }
   ],
   "source": [
    "from hyperparameter_hunter import Environment, CVExperiment\n",
    "from sklearn.metrics import f1_score\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "env = Environment(\n",
    "    train_dataset=train_df,\n",
    "    results_path=\"HyperparameterHunterAssets\",\n",
    "    target_column=\"y\",\n",
    "    metrics=dict(f1=lambda y_true, y_pred: f1_score(y_true, y_pred, average=\"micro\")),\n",
    "    cv_type=KFold,\n",
    "    cv_params=dict(n_splits=5, random_state=32),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that HyperparameterHunter has an active `Environment`, we can do two things:\n",
    "\n",
    "# 1. Perform Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<21:15:02> Validated Environment:  'S8Q6MRvEmEfVCMfR0XLC03UB2-5lkgXQPowUlVgqREs='\n",
      "<21:15:02> Initialized Experiment: 'f29d973e-fb49-4044-a97f-e210dd87f0f1'\n",
      "<21:15:02> Hyperparameter Key:     '1b9sh7OR9oG66Hz0_L_kCWJLqKWeytMNgvGp_0JXLJ8='\n",
      "<21:15:02> \n",
      "<21:15:52> F0.0 AVG:   OOF(f1=0.72561)  |  Time Elapsed: 49.73871 s\n",
      "<21:16:41> F0.1 AVG:   OOF(f1=0.72578)  |  Time Elapsed: 49.28626 s\n",
      "<21:17:31> F0.2 AVG:   OOF(f1=0.72496)  |  Time Elapsed: 49.28679 s\n",
      "<21:18:20> F0.3 AVG:   OOF(f1=0.72663)  |  Time Elapsed: 49.53078 s\n",
      "<21:19:09> F0.4 AVG:   OOF(f1=0.72581)  |  Time Elapsed: 49.38508 s\n",
      "<21:19:10> \n",
      "<21:19:10> FINAL:    OOF(f1=0.72576)  |  Time Elapsed: 4.0 m, 7.4541 s\n",
      "<21:19:10> \n",
      "<21:19:10> Saving results for Experiment: 'f29d973e-fb49-4044-a97f-e210dd87f0f1'\n"
     ]
    }
   ],
   "source": [
    "from catboost import CatBoostClassifier\n",
    "\n",
    "experiment = CVExperiment(\n",
    "    model_initializer=CatBoostClassifier,\n",
    "    model_init_params=dict(\n",
    "        iterations=100,\n",
    "        learning_rate=0.03,\n",
    "        depth=6,\n",
    "        save_snapshot=False,\n",
    "        allow_writing_files=False,\n",
    "        loss_function=\"MultiClass\",\n",
    "        classes_count=7,\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validated Environment with key: \"S8Q6MRvEmEfVCMfR0XLC03UB2-5lkgXQPowUlVgqREs=\"\n",
      "\u001b[31mSaved Result Files\u001b[0m\n",
      "\u001b[31m______________________________________________________________________\u001b[0m\n",
      " Step |       ID |   Time |      Value |     depth |   learning_rate | \n",
      "Experiments matching cross-experiment key/algorithm: 1\n",
      "Experiments fitting in the given space: 1\n",
      "Experiments matching current guidelines: 1\n",
      "    0 | f29d973e | 00m00s | \u001b[35m   0.72576\u001b[0m | \u001b[32m        6\u001b[0m | \u001b[32m         0.0300\u001b[0m | \n",
      "\u001b[31mHyperparameter Optimization\u001b[0m\n",
      "\u001b[31m______________________________________________________________________\u001b[0m\n",
      " Step |       ID |   Time |      Value |     depth |   learning_rate | \n",
      "    1 | 14757f77 | 04m53s | \u001b[35m   0.88707\u001b[0m | \u001b[32m        9\u001b[0m | \u001b[32m         0.4295\u001b[0m | \n",
      "    2 | 9aa18f5d | 07m16s | \u001b[35m   0.89922\u001b[0m | \u001b[32m       12\u001b[0m | \u001b[32m         0.2106\u001b[0m | \n",
      "    3 | 014c6375 | 06m04s | \u001b[35m   0.90904\u001b[0m | \u001b[32m       11\u001b[0m | \u001b[32m         0.3770\u001b[0m | \n",
      "    4 | 756d3178 | 04m16s |    0.85418 |         7 |          0.4081 | \n",
      "    5 | 916550d7 | 09m36s | \u001b[35m   0.91320\u001b[0m | \u001b[32m       13\u001b[0m | \u001b[32m         0.2280\u001b[0m | \n",
      "    6 | 6d85a3bf | 03m51s |    0.80467 |         5 |          0.3046 | \n",
      "    7 | 367ad1ff | 04m01s |    0.83565 |         6 |          0.4135 | \n",
      "    8 | dddbae8d | 13m50s |    0.91296 |        14 |          0.1727 | \n",
      "Optimization loop completed in 0:53:50.941711\n",
      "Best score was 0.9132031696419353 from Experiment \"916550d7-bfea-4ba4-8a2f-3843cc5c5962\"\n"
     ]
    }
   ],
   "source": [
    "from hyperparameter_hunter import GBRT, Real, Integer, Categorical\n",
    "\n",
    "optimizer = GBRT(iterations=8, random_state=42)\n",
    "\n",
    "optimizer.forge_experiment(\n",
    "    model_initializer=CatBoostClassifier,\n",
    "    model_init_params=dict(\n",
    "        iterations=100,\n",
    "        learning_rate=Real(low=0.0001, high=0.5),\n",
    "        depth=Integer(4, 15),\n",
    "        save_snapshot=False,\n",
    "        allow_writing_files=False,\n",
    "        loss_function=\"MultiClass\",\n",
    "        classes_count=7,\n",
    "    ),\n",
    ")\n",
    "\n",
    "optimizer.go()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice, `optimizer` recognizes our earlier `experiment`'s hyperparameters fit inside the search space/guidelines set for `optimizer`.\n",
    "\n",
    "Then, when optimization is started, it automatically learns from `experiment`'s results - without any extra work for us!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
